#ifndef HELPER_CUH
#define HELPER_CUH

#include <algorithm>
#include <cstdint>
#include <sstream>
#include <stdio.h>
#include <type_traits>
#include <vector>

// Returns the device-side printf format for a single element of type T,
// followed by the ", " separator used by print_debug_kernel.
// The primary template is only declared: element types without one of the
// specializations below fail to build, so add a specialization to support one.
template <typename T> __device__ inline const char *get_format();

// Format for `int` elements.
template <>
__device__ inline const char *get_format<int>() {
  return "%d, ";
}

// Format for `unsigned int` elements.
template <>
__device__ inline const char *get_format<unsigned int>() {
  return "%u, ";
}

// Format for `uint64_t` elements.
// NOTE(review): "%lu" assumes uint64_t is `unsigned long` (LP64, e.g. Linux);
// an LLP64 target would need "%llu" instead — confirm supported platforms.
template <>
__device__ inline const char *get_format<uint64_t>() {
  return "%lu, ";
}

// Prints src[0], ..., src[N-1] with the per-type format from get_format<T>().
// The loop ignores threadIdx/blockIdx, so every thread that runs this prints
// the whole array; print_debug launches it as <<<1, 1>>>.
// `src` must be readable from the device.
template <typename T> __global__ void print_debug_kernel(const T *src, int N) {
  const char *const element_fmt = get_format<T>();
  for (int idx = 0; idx < N; ++idx)
    printf(element_fmt, src[idx]);
}

// __uint128_t specialization: printf has no 128-bit conversion, so each
// element is split into its upper and lower 64-bit halves and printed as
// "(high, low), ".
template <>
__global__ inline void print_debug_kernel(const __uint128_t *src, int N) {
  for (int i = 0; i < N; i++) {
    // Cast to unsigned long long so the arguments match %llu exactly;
    // uint64_t is `unsigned long` on LP64 targets, which mismatches %llu.
    const unsigned long long low = static_cast<unsigned long long>(src[i]);
    const unsigned long long high =
        static_cast<unsigned long long>(src[i] >> 64);
    printf("(%llu, %llu), ", high, low);
  }
}

// double2 specialization: prints each element as "(x, y), ".
template <>
__global__ inline void print_debug_kernel(const double2 *src, int N) {
  for (int idx = 0; idx < N; ++idx) {
    const double2 value = src[idx];
    printf("(%lf, %lf), ", value.x, value.y);
  }
}
// Host entry point: prints "<name>: " and then the N device-resident elements
// at `src` (via a single-thread print_debug_kernel launch), then a newline.
// The device is synchronized before the launch so earlier work that writes
// `src` has finished, and after it so the kernel's printf output is flushed
// before the trailing newline. No CUDA errors are checked.
template <typename T> void print_debug(const char *name, const T *src, int N) {
  const dim3 single(1);
  printf("%s: ", name);
  cudaDeviceSynchronize();
  print_debug_kernel<<<single, single>>>(src, N);
  cudaDeviceSynchronize();
  printf("\n");
}

// Treats `src` as N consecutive records of (lwe_dimension + 1) elements.
// For each record it reads the last element ("body") and prints it together
// with body / delta (integer division) as "(body, body/delta), ".
// Intended for a single-thread launch (see print_body); every thread that
// runs it prints all N records. Precondition: delta != 0.
template <typename T>
__global__ void print_body_kernel(T *src, int N, int lwe_dimension, T delta) {
  // Compute the record stride in size_t so i * stride cannot overflow int.
  const size_t stride = static_cast<size_t>(lwe_dimension) + 1;
  for (int i = 0; i < N; i++) {
    T body = src[i * stride + lwe_dimension];
    T clear = body / delta;
    // Cast explicitly: "%lu" with a 32-bit T reads the wrong argument width.
    printf("(%llu, %llu), ", static_cast<unsigned long long>(body),
           static_cast<unsigned long long>(clear));
  }
}

// Host entry point: prints "<name>: ", then one "(body, body/delta), " pair
// for each of the n records in device buffer `src` (via print_body_kernel on
// a single thread), then a newline. The device is synchronized before and
// after the launch. No CUDA errors are checked.
template <typename T>
void print_body(const char *name, T *src, int n, int lwe_dimension, T delta) {
  const dim3 single(1);
  printf("%s: ", name);
  cudaDeviceSynchronize();
  print_body_kernel<<<single, single>>>(src, n, lwe_dimension, delta);
  cudaDeviceSynchronize();
  printf("\n");
}

// Writes `v` to the file `fname` as CSV with `col_size` values per row.
// Values are printed as unsigned decimal integers. Only complete rows are
// written; trailing elements that do not fill a whole row are dropped.
// Does nothing if col_size <= 0. Reports to stderr and returns if the file
// cannot be opened.
template <typename Torus>
void print_2d_csv_to_file(const std::vector<Torus> &v, int col_size,
                          const char *fname) {
  // Guard against a division by zero (and a nonsensical negative width).
  if (col_size <= 0)
    return;

  FILE *fp = fopen(fname, "wt");
  if (fp == nullptr) {
    fprintf(stderr, "print_2d_csv_to_file: cannot open %s for writing\n",
            fname);
    return;
  }

  const size_t cols = static_cast<size_t>(col_size);
  const size_t rows = v.size() / cols;
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      // Cast explicitly: "%lu" with a 32-bit Torus is undefined behaviour.
      fprintf(fp, "%llu%c", static_cast<unsigned long long>(v[i * cols + j]),
              (j == cols - 1) ? '\n' : ',');
    }
  }
  fclose(fp);
}

// Copies a row_size x col_size matrix of Torus values from device memory at
// `ptr` to the host and writes it as CSV to
// "<fname_prefix>_<row_size>_<col_size>_<rand_prefix>.csv".
// The copy is issued on `stream` for GPU `gpu_index`. Does nothing when either
// dimension is <= 0.
template <typename Torus>
__host__ void dump_2d_gpu_to_file(const Torus *ptr, int row_size, int col_size,
                                  const char *fname_prefix, int rand_prefix,
                                  cudaStream_t stream, uint32_t gpu_index) {
  // #ifndef NDEBUG
  // Degenerate sizes: nothing to dump. This also avoids taking the address of
  // an empty vector's storage and a zero column count in the CSV writer.
  if (row_size <= 0 || col_size <= 0)
    return;

  // Compute the element count in size_t so large matrices don't overflow int.
  const size_t num_elements =
      static_cast<size_t>(row_size) * static_cast<size_t>(col_size);
  std::vector<Torus> buf_cpu(num_elements);

  char fname[4096];
  snprintf(fname, sizeof(fname), "%s_%d_%d_%d.csv", fname_prefix, row_size,
           col_size, rand_prefix);

  cuda_memcpy_async_to_cpu((void *)buf_cpu.data(), ptr,
                           num_elements * sizeof(Torus), stream, gpu_index);
  // The copy is presumably asynchronous (per the helper's name); wait for the
  // device before reading buf_cpu on the host.
  cuda_synchronize_device(gpu_index);
  print_2d_csv_to_file(buf_cpu, col_size, fname);
  // #endif
}

// Copies two row_size x col_size device matrices to the host and compares them
// element-wise. If any element differs, it fails GPU_ASSERT with a message
// listing up to the first 10 differing indices and their values.
// Degenerate sizes (either dimension <= 0) are treated as trivially equal.
template <typename Torus>
__host__ void compare_2d_arrays(const Torus *ptr1, const Torus *ptr2,
                                int row_size, int col_size, cudaStream_t stream,
                                uint32_t gpu_index) {
  // #ifndef NDEBUG
  if (row_size <= 0 || col_size <= 0)
    return;

  // Compute the element count in size_t so large matrices don't overflow int.
  const size_t num_elements =
      static_cast<size_t>(row_size) * static_cast<size_t>(col_size);
  std::vector<Torus> buf_cpu1(num_elements), buf_cpu2(num_elements);
  cuda_memcpy_async_to_cpu((void *)buf_cpu1.data(), ptr1,
                           num_elements * sizeof(Torus), stream, gpu_index);
  cuda_memcpy_async_to_cpu((void *)buf_cpu2.data(), ptr2,
                           num_elements * sizeof(Torus), stream, gpu_index);
  // Wait for both copies before reading the host buffers.
  cuda_synchronize_device(gpu_index);

  // Only the first few mismatches are reported, so stop collecting there
  // instead of storing every differing index.
  constexpr size_t max_reported = 10;
  std::vector<size_t> non_matching_indexes;
  for (size_t i = 0; i < num_elements; ++i) {
    if (buf_cpu1[i] != buf_cpu2[i]) {
      non_matching_indexes.push_back(i);
      if (non_matching_indexes.size() == max_reported)
        break;
    }
  }

  if (!non_matching_indexes.empty()) {
    std::stringstream ss;
    for (size_t idx : non_matching_indexes) {
      ss << "    difference at " << idx << ": " << buf_cpu1[idx] << " vs "
         << buf_cpu2[idx] << " at index " << idx << "\n";
    }
    // The condition is always false here; GPU_ASSERT is used to report the
    // collected mismatches.
    GPU_ASSERT(non_matching_indexes.empty(),
               "Correctness error for matrices %d x %d: \n%s", row_size,
               col_size, ss.str().c_str());
  }
}

#endif
